import os
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
sys.path.append('path to file storage d2lzh1981')
import d2l
def masked_softmax(X, valid_length):
    # X: 3-D tensor, valid_length: 1-D or 2-D tensor
    softmax = nn.Softmax(dim=-1)
    if valid_length is None:
        return softmax(X)
    else:
        shape = X.shape
        if valid_length.dim() == 1:
            try:
                valid_length = torch.FloatTensor(valid_length.numpy().repeat(shape[1], axis=0))  # e.g. [2, 2, 3, 3]
            except:
                valid_length = torch.FloatTensor(valid_length.cpu().numpy().repeat(shape[1], axis=0))  # e.g. [2, 2, 3, 3]
        else:
            valid_length = valid_length.reshape((-1,))
        # Fill masked elements with a large negative value, whose exp is ~0
        X = SequenceMask(X.reshape((-1, shape[-1])), valid_length)
        return softmax(X).reshape(shape)
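masked_softmax calls a SequenceMask helper that is defined earlier in the d2l notebook and not included in this excerpt. A minimal sketch, assuming positions at or beyond the valid length are filled with -1e6 so their softmax weight becomes effectively zero:

def SequenceMask(X, X_len, value=-1e6):
    # X: (rows, seq_len); X_len: (rows,) giving the valid length of each row
    maxlen = X.size(1)
    X_len = X_len.to(X.device)
    # True where the position index is past the valid length
    mask = torch.arange(maxlen, dtype=torch.float, device=X.device)[None, :] >= X_len[:, None]
    X[mask] = value
    return X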
# Save to the d2l package.
class DotProductAttention(nn.Module):
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
    # query: (batch_size, #queries, d)
    # key: (batch_size, #kv_pairs, d)
    # value: (batch_size, #kv_pairs, dim_v)
    # valid_length: either (batch_size, ) or (batch_size, #queries)
    def forward(self, query, key, value, valid_length=None):
        d = query.shape[-1]
        # key.transpose(1, 2) swaps the last two dimensions of key so that
        # bmm computes the query-key dot products
        scores = torch.bmm(query, key.transpose(1, 2)) / math.sqrt(d)
        attention_weights = self.dropout(masked_softmax(scores, valid_length))
        return torch.bmm(attention_weights, value)
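A toy shape check (hypothetical tensors, relying on the SequenceMask sketch above): one query per example attends over ten key-value pairs, with the first example allowed to see 2 of them and the second 6.

atten = DotProductAttention(dropout=0)
keys = torch.ones((2, 10, 2), dtype=torch.float)
values = torch.arange(40, dtype=torch.float).view(1, 10, 4).repeat(2, 1, 1)
query = torch.ones((2, 1, 2), dtype=torch.float)
# Output shape: (batch_size, #queries, dim_v) = (2, 1, 4)
print(atten(query, keys, values, torch.FloatTensor([2, 6])))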
def transpose_qkv(X, num_heads):
    # Original X shape: (batch_size, seq_len, hidden_size * num_heads);
    # -1 means inferring its value. After the first reshape, X shape:
    # (batch_size, seq_len, num_heads, hidden_size)
    X = X.view(X.shape[0], X.shape[1], num_heads, -1)

    # After transpose, X shape: (batch_size, num_heads, seq_len, hidden_size)
    X = X.transpose(2, 1).contiguous()
    # Merge the first two dimensions; -1 lets view infer the merged size.
    # Output shape: (batch_size * num_heads, seq_len, hidden_size)
    output = X.view(-1, X.shape[2], X.shape[3])
    return output
# Saved in the d2l package for later use
def transpose_output(X, num_heads):
    # A reversed version of transpose_qkv
    X = X.view(-1, num_heads, X.shape[1], X.shape[2])
    X = X.transpose(2, 1).contiguous()
    return X.view(X.shape[0], X.shape[1], -1)
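transpose_qkv and transpose_output are the reshaping helpers for multi-head attention, whose module is not included in this excerpt. A minimal sketch, assuming single linear projections of size hidden_size that are split across the heads (hidden_size must be divisible by num_heads) and reusing DotProductAttention from above; names and arguments are illustrative, not the exact d2l API:

class MultiHeadAttention(nn.Module):
    def __init__(self, input_size, hidden_size, num_heads, dropout, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.Linear(input_size, hidden_size, bias=False)
        self.W_k = nn.Linear(input_size, hidden_size, bias=False)
        self.W_v = nn.Linear(input_size, hidden_size, bias=False)
        self.W_o = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, query, key, value, valid_length):
        # Project, then split hidden_size across heads:
        # (batch_size, seq_len, hidden_size) ->
        # (batch_size * num_heads, seq_len, hidden_size / num_heads)
        query = transpose_qkv(self.W_q(query), self.num_heads)
        key = transpose_qkv(self.W_k(key), self.num_heads)
        value = transpose_qkv(self.W_v(value), self.num_heads)
        if valid_length is not None:
            # Each example occupies num_heads consecutive rows in the expanded
            # batch, so repeat its valid length num_heads times in that order
            valid_length = torch.repeat_interleave(valid_length,
                                                   repeats=self.num_heads, dim=0)
        output = self.attention(query, key, value, valid_length)
        # Undo the head split and apply the output projection
        return self.W_o(transpose_output(output, self.num_heads))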
import numpy as np

pe = PositionalEncoding(20, 0)
Y = pe(torch.zeros((1, 100, 20))).numpy()
d2l.plot(np.arange(100), Y[0, :, 4:8].T, figsize=(6, 2.5),
         legend=["dim %d" % p for p in [4, 5, 6, 7]])
(Figure: sinusoidal positional-encoding values for dimensions 4-7 plotted against positions 0-99.)
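The cell above exercises a PositionalEncoding module that is not defined in this excerpt. A minimal sketch, assuming the standard sinusoidal encoding (even embedding_size) with dropout applied after the positional codes are added:

class PositionalEncoding(nn.Module):
    def __init__(self, embedding_size, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # P[0, pos, 2i]   = sin(pos / 10000^(2i / embedding_size))
        # P[0, pos, 2i+1] = cos(pos / 10000^(2i / embedding_size))
        self.P = torch.zeros((1, max_len, embedding_size))
        X = torch.arange(0, max_len, dtype=torch.float32).reshape(-1, 1) / torch.pow(
            10000, torch.arange(0, embedding_size, 2, dtype=torch.float32) / embedding_size)
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, X):
        # Add the fixed positional codes for the first X.shape[1] positions
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)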
Encoder
With the modules that make up the Transformer in place, we can assemble an encoder. The encoder stacks blocks, each containing a multi-head attention layer, a position-wise FFN, and two "Add and Norm" layers, as sketched below.
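As a concrete reference, here is a minimal sketch of such a block. AddNorm and PositionWiseFFN are restated here because their earlier definitions are not part of this excerpt, and MultiHeadAttention refers to the sketch above; names and argument lists are illustrative rather than the exact d2l API.

class AddNorm(nn.Module):
    # Residual connection followed by layer normalization
    def __init__(self, embedding_size, dropout, **kwargs):
        super(AddNorm, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(embedding_size)

    def forward(self, X, Y):
        return self.norm(self.dropout(Y) + X)


class PositionWiseFFN(nn.Module):
    # Two dense layers applied independently at every position
    def __init__(self, input_size, ffn_hidden_size, hidden_size_out, **kwargs):
        super(PositionWiseFFN, self).__init__(**kwargs)
        self.ffn_1 = nn.Linear(input_size, ffn_hidden_size)
        self.ffn_2 = nn.Linear(ffn_hidden_size, hidden_size_out)

    def forward(self, X):
        return self.ffn_2(F.relu(self.ffn_1(X)))


class EncoderBlock(nn.Module):
    def __init__(self, embedding_size, ffn_hidden_size, num_heads, dropout, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)
        # embedding_size must be divisible by num_heads
        self.attention = MultiHeadAttention(embedding_size, embedding_size,
                                            num_heads, dropout)
        self.addnorm_1 = AddNorm(embedding_size, dropout)
        self.ffn = PositionWiseFFN(embedding_size, ffn_hidden_size, embedding_size)
        self.addnorm_2 = AddNorm(embedding_size, dropout)

    def forward(self, X, valid_length):
        # Self-attention sublayer, then position-wise FFN sublayer,
        # each wrapped in "Add and Norm"
        Y = self.addnorm_1(X, self.attention(X, X, X, valid_length))
        return self.addnorm_2(Y, self.ffn(Y))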
    # forward method of the TransformerEncoder (the class header and __init__
    # are not shown in this excerpt)
    def forward(self, X, valid_length, *args):
        # Scale embeddings by sqrt(embedding_size) so they are on a comparable
        # scale to the positional encodings before the two are summed
        X = self.pos_encoding(self.embed(X) * math.sqrt(self.embedding_size))
        for blk in self.blks:
            X = blk(X, valid_length)
        return X
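This forward belongs to a TransformerEncoder class whose constructor is not included in the excerpt. A minimal sketch, reusing the EncoderBlock sketch above (argument names are illustrative):

class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, embedding_size, ffn_hidden_size,
                 num_heads, num_layers, dropout, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.embedding_size = embedding_size
        self.embed = nn.Embedding(vocab_size, embedding_size)
        self.pos_encoding = PositionalEncoding(embedding_size, dropout)
        # Stack num_layers identical encoder blocks
        self.blks = nn.ModuleList(
            [EncoderBlock(embedding_size, ffn_hidden_size, num_heads, dropout)
             for _ in range(num_layers)])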
    # forward method of the TransformerDecoder (the class header and __init__
    # are not shown in this excerpt)
    def forward(self, X, state):
        X = self.pos_encoding(self.embed(X) * math.sqrt(self.embedding_size))
        for blk in self.blks:
            X, state = blk(X, state)
        return self.dense(X), state
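Likewise, this is the forward of a TransformerDecoder. A minimal sketch of the surrounding class, assuming a DecoderBlock module (not shown in this excerpt) that takes a block index so it can cache already-decoded positions in state during prediction:

class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, embedding_size, ffn_hidden_size,
                 num_heads, num_layers, dropout, **kwargs):
        super(TransformerDecoder, self).__init__(**kwargs)
        self.embedding_size = embedding_size
        self.num_layers = num_layers
        self.embed = nn.Embedding(vocab_size, embedding_size)
        self.pos_encoding = PositionalEncoding(embedding_size, dropout)
        self.blks = nn.ModuleList(
            [DecoderBlock(embedding_size, ffn_hidden_size, num_heads, dropout, i)
             for i in range(num_layers)])
        # Project the block outputs back to vocabulary logits
        self.dense = nn.Linear(embedding_size, vocab_size)

    def init_state(self, enc_outputs, enc_valid_length, *args):
        # state = [encoder outputs, encoder valid lengths,
        #          per-block cache of already-decoded positions]
        return [enc_outputs, enc_valid_length, [None] * self.num_layers]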
import zipfile
import torch
import requests
from io import BytesIO
from torch.utils import data
import sys
import collections
class Vocab(object):
    def __init__(self, tokens, min_freq=0, use_special_tokens=False):
        # Sort by frequency and token
        counter = collections.Counter(tokens)
        token_freqs = sorted(counter.items(), key=lambda x: x[0])
        token_freqs.sort(key=lambda x: x[1], reverse=True)
        if use_special_tokens:
            # padding, begin of sentence, end of sentence, unknown
            self.pad, self.bos, self.eos, self.unk = (0, 1, 2, 3)
            tokens = ['<pad>', '<bos>', '<eos>', '<unk>']
        else:
            self.unk = 0
            tokens = ['<unk>']
        tokens += [token for token, freq in token_freqs if freq >= min_freq]
        self.idx_to_token = []
        self.token_to_idx = dict()
        for token in tokens:
            self.idx_to_token.append(token)
            self.token_to_idx[token] = len(self.idx_to_token) - 1

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        else:
            return [self.__getitem__(token) for token in tokens]

    def to_tokens(self, indices):
        if not isinstance(indices, (list, tuple)):
            return self.idx_to_token[indices]
        else:
            return [self.idx_to_token[index] for index in indices]
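A quick sanity check on a toy token list (hypothetical data; the exact indices depend on the frequency ordering):

toy_tokens = 'the cat sat on the mat the end'.split()
vocab = Vocab(toy_tokens)
print(len(vocab))                      # number of distinct tokens plus '<unk>'
print(vocab['the'])                    # index of a known token
print(vocab[['the', 'cat', 'xyzzy']])  # unknown tokens fall back to vocab.unk
print(vocab.to_tokens([1, 2, 3]))      # map indices back to tokens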
def load_data_nmt(batch_size, max_len, num_examples=1000):
    """Download an NMT dataset, return its vocabulary and data iterator."""
    # Download and preprocess
    def preprocess_raw(text):
        text = text.replace('\u202f', ' ').replace('\xa0', ' ')
        out = ''
        for i, char in enumerate(text.lower()):
            if char in (',', '!', '.') and text[i-1] != ' ':
                out += ' '
            out += char
        return out
    with open('path to data file', 'r') as f:
        raw_text = f.read()
    text = preprocess_raw(raw_text)
    # Tokenize
    source, target = [], []
    for i, line in enumerate(text.split('\n')):
        if i >= num_examples:
            break
        parts = line.split('\t')
        if len(parts) >= 2:
            source.append(parts[0].split(' '))
            target.append(parts[1].split(' '))
    # Build vocab
    def build_vocab(tokens):
        tokens = [token for line in tokens for token in line]
        return Vocab(tokens, min_freq=3, use_special_tokens=True)
    src_vocab, tgt_vocab = build_vocab(source), build_vocab(target)
    # Convert to index arrays
    def pad(line, max_len, padding_token):
        if len(line) > max_len:
            return line[:max_len]
        return line + [padding_token] * (max_len - len(line))
    def build_array(lines, vocab, max_len, is_source):
        lines = [vocab[line] for line in lines]
        if not is_source:
            lines = [[vocab.bos] + line + [vocab.eos] for line in lines]
        array = torch.tensor([pad(line, max_len, vocab.pad) for line in lines])
        valid_len = (array != vocab.pad).sum(1)
        return array, valid_len
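    # --- Sketch (not part of the original excerpt): the remaining body of
    # load_data_nmt would typically convert the token lists to padded index
    # arrays and wrap them in a DataLoader, roughly as follows. The names
    # train_data and train_iter are illustrative.
    src_array, src_valid_len = build_array(source, src_vocab, max_len, True)
    tgt_array, tgt_valid_len = build_array(target, tgt_vocab, max_len, False)
    train_data = data.TensorDataset(src_array, src_valid_len,
                                    tgt_array, tgt_valid_len)
    train_iter = data.DataLoader(train_data, batch_size, shuffle=True)
    return src_vocab, tgt_vocab, train_iter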